
Conversation

skc7
Contributor

@skc7 skc7 commented May 19, 2025

This PR introduces a new pass, "lower-workdistribute".
Fortran array statements are lowered to FIR as fir.do_loop unordered loops. The "lower-workdistribute" pass mainly identifies a fir.do_loop unordered nested as target{teams{workdistribute{fir.do_loop unordered}}} and lowers it to target{teams{parallel{wsloop{loop_nest}}}}. It hoists all other ops outside the target region, replaces heap allocation on the target with omp.target_allocmem and deallocation with omp.target_freemem issued from the host, and replaces the runtime function "Assign" with omp.target_memcpy issued from the host.
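Roughly, the end-to-end effect on a simple example looks like this (illustrative only; the exact FIR/OpenMP IR differs in details such as bounds, declares, and map clauses):

! Fortran source (illustrative)
!$omp target teams workdistribute
  a = b + c      ! array assignment, lowered by Flang to a fir.do_loop ... unordered
!$omp end target teams workdistribute

! Before the pass (schematic nesting):
omp.target { omp.teams { omp.workdistribute { fir.do_loop unordered { ... } } } }

! After the pass (schematic nesting):
omp.target { omp.teams { omp.parallel { omp.distribute { omp.wsloop { omp.loop_nest { ... } } } } } }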

This pass implements the following rewrites and optimisations:

  • FissionWorkdistribute: finds the parallelizable ops within the teams {workdistribute} region and moves each of them into its own teams {workdistribute} region.
  • WorkdistributeRuntimeCallLower: finds FortranAAssign calls nested in teams {workdistribute{}} and lowers them to an unordered do loop when the source is a scalar and the destination is an array. Other runtime calls are not handled currently.
  • WorkdistributeDoLower: finds the fir.do_loop unordered nested in teams {workdistribute{fir.do_loop unordered}} and lowers it to teams {parallel {distribute {wsloop {loop_nest}}}}.
  • TeamsWorkdistributeToSingle: hoists all the ops inside teams {workdistribute{}} to before the teams op.

The work in this PR is copied and updated from @ivanradanov's commits from the coexecute implementation:
flang_workdistribute_iwomp_2024

Paper related to this work by @ivanradanov: "Automatic Parallelization and OpenMP Offloading of Fortran Array Notation"

@skc7 skc7 force-pushed the skc7/flang_workdistribute branch from 6e8010d to df65bd5 Compare May 19, 2025 14:18
@skc7 skc7 requested a review from mjklemm May 20, 2025 07:12
@skc7 skc7 marked this pull request as draft May 20, 2025 07:16
@skc7 skc7 self-assigned this May 20, 2025

github-actions bot commented Jun 5, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@skc7 skc7 force-pushed the skc7/flang_workdistribute branch from bcf51ad to 8669a40 Compare June 23, 2025 08:51
@skc7 skc7 changed the title [WIP] Implement workdistribute construct [flang] Implement workdistribute construct lowering Jun 27, 2025
@ivanradanov
Contributor

I am not sure that using rewrite patterns is suitable here - mainly because we have a list of transformations that we know we need to do, and in what order, so there is no need to involve patterns, which add additional overhead and complexity (guarding against matching again, making sure the order is correct, etc.). That's how the LowerWorkshare pass is implemented.

@skc7
Contributor Author

skc7 commented Jul 4, 2025

I am not sure that using rewrite patterns is suitable here - mainly because we have a list of transformations that we know we need to do, and in what order, so there is no need to involve patterns, which add additional overhead and complexity (guarding against matching again, making sure the order is correct, etc.). That's how the LowerWorkshare pass is implemented.

Thanks @ivanradanov for the feedback. Have updated the PR, removing the rewrite patterns.

@skc7 skc7 marked this pull request as ready for review July 4, 2025 09:10
@skc7 skc7 requested review from ivanradanov, ergawy and kparzysz July 16, 2025 16:26
@skc7 skc7 force-pushed the skc7/flang_workdistribute branch from b6ca26a to d253390 Compare July 25, 2025 11:47
@llvmbot llvmbot added mlir:llvm mlir flang Flang issues not falling into any other category mlir:openmp flang:fir-hlfir flang:openmp labels Jul 25, 2025
@llvmbot
Member

llvmbot commented Jul 25, 2025

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-flang-openmp

Author: Chaitanya (skc7)

Changes

Pre-requisite PRs:
[flang] Introduce omp.target_allocmem and omp.target_freemem omp dialect ops .
[flang-rt] Add Assign_omp RT call.
[flang] Add support for workdistribute construct in flang frontend

This PR introduces a new pass, "lower-workdistribute", which identifies parallelizable ops inside the workdistribute region and moves them to a new omp.target region.

This pass implements the following rewrites and optimisations:
FissionWorkdistribute, WorkdistributeDoLower and TeamsWorkdistributeToSingle.

After the match and rewrite, the omp.target is moved under an omp.target_data region; only the parallelizable ops are then moved to a new omp.target region, while all other ops are moved to the host.
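Schematically, the target split looks roughly like this (illustrative only; map-type and capture details are simplified):

omp.target map_entries(%a, %b) {           // original kernel containing several workdistribute pieces
  ...
}

// becomes

omp.target_data map_entries(%a, %b) {      // outer region keeps the original data movement
  omp.target map_entries(...) { ... }      // split sub-kernel 1
  omp.target map_entries(...) { ... }      // split sub-kernel 2
  // host-side ops hoisted out of the target (e.g. omp.target_allocmem / omp.target_freemem)
}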

The work in this PR is copied and updated from @ivanradanov's commits from the coexecute implementation:
flang_workdistribute_iwomp_2024


Patch is 58.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140523.diff

10 Files Affected:

  • (modified) flang/include/flang/Optimizer/OpenMP/Passes.td (+4)
  • (modified) flang/lib/Optimizer/OpenMP/CMakeLists.txt (+1)
  • (added) flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp (+901)
  • (modified) flang/lib/Optimizer/Passes/Pipelines.cpp (+3-1)
  • (modified) flang/test/Fir/basic-program.fir (+1)
  • (added) flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir (+33)
  • (added) flang/test/Transforms/OpenMP/lower-workdistribute-fission-target.mlir (+112)
  • (added) flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir (+71)
  • (added) flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir (+32)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+3)
diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 704faf0ccd856..743b6d381ed42 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -93,6 +93,10 @@ def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> {
   let summary = "Lower workshare construct";
 }
 
+def LowerWorkdistribute : Pass<"lower-workdistribute", "::mlir::ModuleOp"> {
+  let summary = "Lower workdistribute construct";
+}
+
 def GenericLoopConversionPass
     : Pass<"omp-generic-loop-conversion", "mlir::func::FuncOp"> {
   let summary = "Converts OpenMP generic `omp.loop` to semantically "
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index e31543328a9f9..cd746834741f9 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -7,6 +7,7 @@ add_flang_library(FlangOpenMPTransforms
   MapsForPrivatizedSymbols.cpp
   MapInfoFinalization.cpp
   MarkDeclareTarget.cpp
+  LowerWorkdistribute.cpp
   LowerWorkshare.cpp
   LowerNontemporal.cpp
 
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
new file mode 100644
index 0000000000000..0885efc716db4
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -0,0 +1,901 @@
+//===- LowerWorkdistribute.cpp - Lower and optimise workdistribute -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lowering and optimisations of omp.workdistribute.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Utils.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
+#include <mlir/Dialect/Utils/IndexingUtils.h>
+#include <mlir/IR/BlockSupport.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/Diagnostics.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/IR/PatternMatch.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
+#include <mlir/Support/LLVM.h>
+#include <optional>
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKDISTRIBUTE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workdistribute"
+
+using namespace mlir;
+
+namespace {
+
+static bool isRuntimeCall(Operation *op) {
+  if (auto callOp = dyn_cast<fir::CallOp>(op)) {
+    auto callee = callOp.getCallee();
+    if (!callee)
+      return false;
+    auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
+    if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName()))
+      return true;
+  }
+  return false;
+}
+
+/// This is the single source of truth about whether we should parallelize an
+/// operation nested in an omp.workdistribute region.
+static bool shouldParallelize(Operation *op) {
+  if (llvm::any_of(op->getResults(),
+                   [](OpResult v) -> bool { return !v.use_empty(); }))
+    return false;
+  // We will parallelize unordered loops - these come from array syntax
+  if (auto loop = dyn_cast<fir::DoLoopOp>(op)) {
+    auto unordered = loop.getUnordered();
+    if (!unordered)
+      return false;
+    return *unordered;
+  }
+  if (isRuntimeCall(op)) {
+    return true;
+  }
+  // We cannot parallise anything else
+  return false;
+}
+
+template <typename T>
+static T getPerfectlyNested(Operation *op) {
+  if (op->getNumRegions() != 1)
+    return nullptr;
+  auto &region = op->getRegion(0);
+  if (region.getBlocks().size() != 1)
+    return nullptr;
+  auto *block = &region.front();
+  auto *firstOp = &block->front();
+  if (auto nested = dyn_cast<T>(firstOp))
+    if (firstOp->getNextNode() == block->getTerminator())
+      return nested;
+  return nullptr;
+}
+
+/// If B() and D() are parallelizable,
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///     C()
+///     D()
+///     E()
+///   }
+/// }
+///
+/// becomes
+///
+/// A()
+/// omp.teams {
+///   omp.workdistribute {
+///     B()
+///   }
+/// }
+/// C()
+/// omp.teams {
+///   omp.workdistribute {
+///     D()
+///   }
+/// }
+/// E()
+
+static bool FissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
+  OpBuilder rewriter(workdistribute);
+  auto loc = workdistribute->getLoc();
+  auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+  if (!teams) {
+    emitError(loc, "workdistribute not nested in teams\n");
+    return false;
+  }
+  if (workdistribute.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "workdistribute with multiple blocks\n");
+    return false;
+  }
+  if (teams.getRegion().getBlocks().size() != 1) {
+    emitError(loc, "teams with multiple blocks\n");
+    return false;
+  }
+
+  auto *teamsBlock = &teams.getRegion().front();
+  bool changed = false;
+  // Move the ops inside teams and before workdistribute outside.
+  IRMapping irMapping;
+  llvm::SmallVector<Operation *> teamsHoisted;
+  for (auto &op : teams.getOps()) {
+    if (&op == workdistribute) {
+      break;
+    }
+    if (shouldParallelize(&op)) {
+      emitError(loc, "teams has parallelize ops before first workdistribute\n");
+      return false;
+    } else {
+      rewriter.setInsertionPoint(teams);
+      rewriter.clone(op, irMapping);
+      teamsHoisted.push_back(&op);
+      changed = true;
+    }
+  }
+  for (auto *op : llvm::reverse(teamsHoisted)) {
+    op->replaceAllUsesWith(irMapping.lookup(op));
+    op->erase();
+  }
+
+  // While we have unhandled operations in the original workdistribute
+  auto *workdistributeBlock = &workdistribute.getRegion().front();
+  auto *terminator = workdistributeBlock->getTerminator();
+  while (&workdistributeBlock->front() != terminator) {
+    rewriter.setInsertionPoint(teams);
+    IRMapping mapping;
+    llvm::SmallVector<Operation *> hoisted;
+    Operation *parallelize = nullptr;
+    for (auto &op : workdistribute.getOps()) {
+      if (&op == terminator) {
+        break;
+      }
+      if (shouldParallelize(&op)) {
+        parallelize = &op;
+        break;
+      } else {
+        rewriter.clone(op, mapping);
+        hoisted.push_back(&op);
+        changed = true;
+      }
+    }
+
+    for (auto *op : llvm::reverse(hoisted)) {
+      op->replaceAllUsesWith(mapping.lookup(op));
+      op->erase();
+    }
+
+    if (parallelize && hoisted.empty() &&
+        parallelize->getNextNode() == terminator)
+      break;
+    if (parallelize) {
+      auto newTeams = rewriter.cloneWithoutRegions(teams);
+      auto *newTeamsBlock = rewriter.createBlock(
+          &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {});
+      for (auto arg : teamsBlock->getArguments())
+        newTeamsBlock->addArgument(arg.getType(), arg.getLoc());
+      auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc);
+      rewriter.create<omp::TerminatorOp>(loc);
+      rewriter.createBlock(&newWorkdistribute.getRegion(),
+                           newWorkdistribute.getRegion().begin(), {}, {});
+      auto *cloned = rewriter.clone(*parallelize);
+      parallelize->replaceAllUsesWith(cloned);
+      parallelize->erase();
+      rewriter.create<omp::TerminatorOp>(loc);
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+/// If fir.do_loop is present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     fir.do_loop unordered {
+///       ...
+///     }
+///   }
+/// }
+///
+/// Then, it is lowered to
+///
+/// omp.teams {
+///   omp.parallel {
+///     omp.distribute {
+///     omp.wsloop {
+///       omp.loop_nest
+///         ...
+///       }
+///     }
+///   }
+/// }
+
+static void genParallelOp(Location loc, OpBuilder &rewriter, bool composite) {
+  auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loc);
+  parallelOp.setComposite(composite);
+  rewriter.createBlock(&parallelOp.getRegion());
+  rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc));
+  return;
+}
+
+static void genDistributeOp(Location loc, OpBuilder &rewriter, bool composite) {
+  mlir::omp::DistributeOperands distributeClauseOps;
+  auto distributeOp =
+      rewriter.create<mlir::omp::DistributeOp>(loc, distributeClauseOps);
+  distributeOp.setComposite(composite);
+  auto distributeBlock = rewriter.createBlock(&distributeOp.getRegion());
+  rewriter.setInsertionPointToStart(distributeBlock);
+  return;
+}
+
+static void
+genLoopNestClauseOps(OpBuilder &rewriter, fir::DoLoopOp loop,
+                     mlir::omp::LoopNestOperands &loopNestClauseOps) {
+  assert(loopNestClauseOps.loopLowerBounds.empty() &&
+         "Loop nest bounds were already emitted!");
+  loopNestClauseOps.loopLowerBounds.push_back(loop.getLowerBound());
+  loopNestClauseOps.loopUpperBounds.push_back(loop.getUpperBound());
+  loopNestClauseOps.loopSteps.push_back(loop.getStep());
+  loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
+}
+
+static void genWsLoopOp(mlir::OpBuilder &rewriter, fir::DoLoopOp doLoop,
+                        const mlir::omp::LoopNestOperands &clauseOps,
+                        bool composite) {
+
+  auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc());
+  wsloopOp.setComposite(composite);
+  rewriter.createBlock(&wsloopOp.getRegion());
+
+  auto loopNestOp =
+      rewriter.create<mlir::omp::LoopNestOp>(doLoop.getLoc(), clauseOps);
+
+  // Clone the loop's body inside the loop nest construct using the
+  // mapped values.
+  rewriter.cloneRegionBefore(doLoop.getRegion(), loopNestOp.getRegion(),
+                             loopNestOp.getRegion().begin());
+  Block *clonedBlock = &loopNestOp.getRegion().back();
+  mlir::Operation *terminatorOp = clonedBlock->getTerminator();
+
+  // Erase fir.result op of do loop and create yield op.
+  if (auto resultOp = dyn_cast<fir::ResultOp>(terminatorOp)) {
+    rewriter.setInsertionPoint(terminatorOp);
+    rewriter.create<mlir::omp::YieldOp>(doLoop->getLoc());
+    // rewriter.erase(terminatorOp);
+    terminatorOp->erase();
+  }
+  return;
+}
+
+static bool WorkdistributeDoLower(omp::WorkdistributeOp workdistribute) {
+  OpBuilder rewriter(workdistribute);
+  auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistribute);
+  auto wdLoc = workdistribute->getLoc();
+  if (doLoop && shouldParallelize(doLoop)) {
+    assert(doLoop.getReduceOperands().empty());
+    genParallelOp(wdLoc, rewriter, true);
+    genDistributeOp(wdLoc, rewriter, true);
+    mlir::omp::LoopNestOperands loopNestClauseOps;
+    genLoopNestClauseOps(rewriter, doLoop, loopNestClauseOps);
+    genWsLoopOp(rewriter, doLoop, loopNestClauseOps, true);
+    workdistribute.erase();
+    return true;
+  }
+  return false;
+}
+
+/// If A() and B() are present inside teams workdistribute
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///   }
+/// }
+///
+/// Then, it is lowered to
+///
+/// A()
+/// B()
+///
+
+static bool TeamsWorkdistributeToSingleOp(omp::TeamsOp teamsOp) {
+  auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+  if (!workdistributeOp)
+    return false;
+  // Get the block containing teamsOp (the parent block).
+  Block *parentBlock = teamsOp->getBlock();
+  Block &workdistributeBlock = *workdistributeOp.getRegion().begin();
+  auto insertPoint = Block::iterator(teamsOp);
+  // Get the range of operations to move (excluding the terminator).
+  auto workdistributeBegin = workdistributeBlock.begin();
+  auto workdistributeEnd = workdistributeBlock.getTerminator()->getIterator();
+  // Move the operations from workdistribute block to before teamsOp.
+  parentBlock->getOperations().splice(insertPoint,
+                                      workdistributeBlock.getOperations(),
+                                      workdistributeBegin, workdistributeEnd);
+  // Erase the now-empty workdistributeOp.
+  workdistributeOp.erase();
+  Block &teamsBlock = *teamsOp.getRegion().begin();
+  // Check if only the terminator remains and erase teams op.
+  if (teamsBlock.getOperations().size() == 1 &&
+      teamsBlock.getTerminator() != nullptr) {
+    teamsOp.erase();
+  }
+  return true;
+}
+
+struct SplitTargetResult {
+  omp::TargetOp targetOp;
+  omp::TargetDataOp dataOp;
+};
+
+/// If multiple workdistribute are nested in a target regions, we will need to
+/// split the target region, but we want to preserve the data semantics of the
+/// original data region and avoid unnecessary data movement at each of the
+/// subkernels - we split the target region into a target_data{target}
+/// nest where only the outer one moves the data
+std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
+                                                 RewriterBase &rewriter) {
+  auto loc = targetOp->getLoc();
+  if (targetOp.getMapVars().empty()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << DEBUG_TYPE << " target region has no data maps\n");
+    return std::nullopt;
+  }
+
+  SmallVector<omp::MapInfoOp> mapInfos;
+  for (auto opr : targetOp.getMapVars()) {
+    auto mapInfo = cast<omp::MapInfoOp>(opr.getDefiningOp());
+    mapInfos.push_back(mapInfo);
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+  SmallVector<Value> innerMapInfos;
+  SmallVector<Value> outerMapInfos;
+
+  for (auto mapInfo : mapInfos) {
+    auto originalMapType =
+        (llvm::omp::OpenMPOffloadMappingFlags)(mapInfo.getMapType());
+    auto originalCaptureType = mapInfo.getMapCaptureType();
+    llvm::omp::OpenMPOffloadMappingFlags newMapType;
+    mlir::omp::VariableCaptureKind newCaptureType;
+
+    if (originalCaptureType == mlir::omp::VariableCaptureKind::ByCopy) {
+      newMapType = originalMapType;
+      newCaptureType = originalCaptureType;
+    } else if (originalCaptureType == mlir::omp::VariableCaptureKind::ByRef) {
+      newMapType = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
+      newCaptureType = originalCaptureType;
+      outerMapInfos.push_back(mapInfo);
+    } else {
+      llvm_unreachable("Unhandled case");
+    }
+    auto innerMapInfo = cast<omp::MapInfoOp>(rewriter.clone(*mapInfo));
+    innerMapInfo.setMapTypeAttr(rewriter.getIntegerAttr(
+        rewriter.getIntegerType(64, false),
+        static_cast<
+            std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+            newMapType)));
+    innerMapInfo.setMapCaptureType(newCaptureType);
+    innerMapInfos.push_back(innerMapInfo.getResult());
+  }
+
+  rewriter.setInsertionPoint(targetOp);
+  auto device = targetOp.getDevice();
+  auto ifExpr = targetOp.getIfExpr();
+  auto deviceAddrVars = targetOp.getHasDeviceAddrVars();
+  auto devicePtrVars = targetOp.getIsDevicePtrVars();
+  auto targetDataOp = rewriter.create<omp::TargetDataOp>(
+      loc, device, ifExpr, outerMapInfos, deviceAddrVars, devicePtrVars);
+  auto targetDataBlock = rewriter.createBlock(&targetDataOp.getRegion());
+  rewriter.create<mlir::omp::TerminatorOp>(loc);
+  rewriter.setInsertionPointToStart(targetDataBlock);
+
+  auto newTargetOp = rewriter.create<omp::TargetOp>(
+      targetOp.getLoc(), targetOp.getAllocateVars(),
+      targetOp.getAllocatorVars(), targetOp.getBareAttr(),
+      targetOp.getDependKindsAttr(), targetOp.getDependVars(),
+      targetOp.getDevice(), targetOp.getHasDeviceAddrVars(),
+      targetOp.getHostEvalVars(), targetOp.getIfExpr(),
+      targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(),
+      targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
+      innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
+      targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+  rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
+                              newTargetOp.getRegion().begin());
+
+  rewriter.replaceOp(targetOp, newTargetOp);
+  return SplitTargetResult{cast<omp::TargetOp>(newTargetOp), targetDataOp};
+}
+
+static std::optional<std::tuple<Operation *, bool, bool>>
+getNestedOpToIsolate(omp::TargetOp targetOp) {
+  if (targetOp.getRegion().empty())
+    return std::nullopt;
+  auto *targetBlock = &targetOp.getRegion().front();
+  for (auto &op : *targetBlock) {
+    bool first = &op == &*targetBlock->begin();
+    bool last = op.getNextNode() == targetBlock->getTerminator();
+    if (first && last)
+      return std::nullopt;
+
+    if (isa<omp::TeamsOp, omp::ParallelOp>(&op))
+      return {{&op, first, last}};
+  }
+  return std::nullopt;
+}
+
+struct TempOmpVar {
+  omp::MapInfoOp from, to;
+};
+
+static bool isPtr(Type ty) {
+  return isa<fir::ReferenceType>(ty) || isa<LLVM::LLVMPointerType>(ty);
+}
+
+static Type getPtrTypeForOmp(Type ty) {
+  if (isPtr(ty))
+    return LLVM::LLVMPointerType::get(ty.getContext());
+  else
+    return fir::LLVMPointerType::get(ty);
+}
+
+static TempOmpVar allocateTempOmpVar(Location loc, Type ty,
+                                     RewriterBase &rewriter) {
+  MLIRContext &ctx = *ty.getContext();
+  Value alloc;
+  Type allocType;
+  auto llvmPtrTy = LLVM::LLVMPointerType::get(&ctx);
+  if (isPtr(ty)) {
+    Type intTy = rewriter.getI32Type();
+    auto one = rewriter.create<LLVM::ConstantOp>(loc, intTy, 1);
+    allocType = llvmPtrTy;
+    alloc = rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, allocType, one);
+    allocType = intTy;
+  } else {
+    allocType = ty;
+    alloc = rewriter.create<fir::AllocaOp>(loc, allocType);
+  }
+  auto getMapInfo = [&](uint64_t mappingFlags, const char *name) {
+    return rewriter.create<omp::MapInfoOp>(
+        loc, alloc.getType(), alloc, TypeAttr::get(allocType),
+        rewriter.getIntegerAttr(rewriter.getIntegerType(64, /*isSigned=*/false),
+                                mappingFlags),
+        rewriter.getAttr<omp::VariableCaptureKindAttr>(
+            omp::VariableCaptureKind::ByRef),
+        /*varPtrPtr=*/Value{},
+        /*members=*/SmallVector<Value>{},
+        /*member_index=*/mlir::ArrayAttr{},
+        /*bounds=*/ValueRange(),
+        /*mapperId=*/mlir::FlatSymbolRefAttr(),
+        /*name=*/rewriter.getStringAttr(name), rewriter.getBoolAttr(false));
+  };
+  uint64_t mapFrom =
+      static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
+  uint64_t mapTo =
+      static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
+  auto mapInfoFrom = getMapInfo(mapFrom, "__flang_workdistribute_from");
+  auto mapInfoTo = getMapInfo(mapTo, "__flang_workdistribute_to");
+  return TempOmpVar{mapInfoFrom, mapInfoTo};
+};
+
+static bool usedOutsideSplit(Value v, Operation *split) {
+  if (!split)
+    return false;
+  auto targetOp = cast<omp::TargetOp>(split->getParentOp());
+  auto *targetBlock = &targetOp.getRegion().front();
+  for (auto *user : v.getUsers()) {
+    while (user->getBlock() != targetBlock) {
+      user = user->getParentOp();
+    }
+    if (!user->isBeforeInBlock(split))
+      return true;
+  }
+  return false;
+};
+
+static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) {
+  if (isa<fir::DeclareOp>(op))
+    return true;
+
+  llvm::SmallVector<MemoryEffects::EffectInstance> effects;
+  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!interface) {
+    return false;
+  }
+  interface.getEffects(effects);
+  if (effects.empty())
+    return true;
+  return false;
+}
+
+...
[truncated]

@llvmbot
Member

llvmbot commented Jul 25, 2025

@llvm/pr-subscribers-mlir


@skc7 skc7 force-pushed the skc7/flang_workdistribute branch from 95ebe8c to 62a8dd7 Compare September 29, 2025 09:18
@skc7 skc7 changed the title [flang] Implement workdistribute construct lowering [Flang][OpenMP] Implement workdistribute construct lowering Sep 30, 2025
@skc7 skc7 force-pushed the skc7/flang_workdistribute branch from 8a0ae0f to d7e8d90 Compare October 6, 2025 04:48
Contributor

@mjklemm mjklemm left a comment

LGTM, but I'd like someone with more code-gen experience to look at that, too.

// omp.target { teams { workdistribute { ... } } } is well formed
// and fails for function calls that don't have lowering implemented yet.
static bool
VerifyTargetTeamsWorkdistribute(omp::WorkdistributeOp workdistribute) {
Contributor

Is the capitalization intended? I thought the coding style is uncapitalized for functions (also applies to other functions in the file)

Suggested change
VerifyTargetTeamsWorkdistribute(omp::WorkdistributeOp workdistribute) {
verifyTargetTeamsWorkdistribute(omp::WorkdistributeOp workdistribute) {

Contributor Author

Thanks for the feedback. This was an unintentional typo. Fixed it for all the function names as well.

if (isa<fir::DeclareOp>(op))
return true;

llvm::SmallVector<MemoryEffects::EffectInstance> effects;
Contributor

Should we reuse isMemoryEffectFree() here?

Contributor Author

Updated.
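
For reference, a minimal sketch of what the helper-based version could look like (assuming the surrounding code from this PR; not the exact committed code):

// Sketch: reuse the upstream MLIR helper instead of collecting effects by hand.
static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) {
  // fir.declare carries no side effects we care about here.
  if (isa<fir::DeclareOp>(op))
    return true;
  // isMemoryEffectFree handles both ops implementing MemoryEffectOpInterface
  // and ops whose nested ops are known to be side-effect free.
  return mlir::isMemoryEffectFree(op);
}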

namespace {

// The isRuntimeCall function is a utility designed to determine
// if a given operation is a call to a Fortran-specific runtime function.
Contributor

@ivanradanov ivanradanov Oct 7, 2025

Should the comments here be doc comments? (/// instead of //) (also applies to other functions).

Contributor Author

Was unaware of this comment style. Fixed this for the entire pass.

// }
// }
// }

Contributor

Suggested change

}
if (shouldParallelize(&op)) {
emitError(loc, "teams has parallelize ops before first workdistribute\n");
return false;
Contributor

I think this is an error that cannot be recovered from so maybe this function should return ErrorOr<bool> or something like that to correctly catch the error in the Pass and signalPassFailure

Contributor

This shouldn't be an error. It's legal to have intervening code in teams before workdistribute.

Contributor Author

As per spec "The workdistribute construct must be a closely nested construct inside a teams construct"

AFAIU, no other parallelizable operations could be present between teams and workdistribute.

Contributor

@kparzysz kparzysz Oct 8, 2025

Closely nested construct means that there shouldn't be any construct nested between teams and workdistribute, not that there can't be any construct in teams other than distribute.

teams {
  parallel {              // wrong
    workdistribute {}
  }
}

teams {
  parallel {              // ok
  }
  workdistribute {
  }
}

Contributor Author

As per INTEL FORTRAN compiler, LINK

A WORKDISTRIBUTE construct must be closely nested in a TEAMS construct; there can be no other construct nested in the TEAMS construct prior to the WORKDISTRIBUTE construct.

The OpenMP spec doesn't specifically mention other constructs being present along with workdistribute inside teams.

@kparzysz So, is it up to vendor to decide on the restrictions for this construct?

CC: @mjklemm

Contributor

I could have sworn there was some wording in the standard that prevented your second // ok example but indeed I cannot find anything like that now.
I think this complicates things quite a lot

target teams {
  a = A() 
  workdistribute {
    B()
  }
  C(a)
}

needs to become something like this

%tmp_a = alloc <num_teams x i32>
target teams {
  %a = A() : i32
  store %a %tmp_a[%team_id]
}
target teams {
  workdistribute {
    B()
  }
}
target teams {
  %a = load %tmp_a[%team_id]
  C(%a)
}

Does that seem right?
That would also mean that %a cannot be used by B() because we would not have a single value for it?

Contributor Author

I have added a verification step to error if any "omp construct" or "parallelizable ops" are present before or after workdistribute inside the teams region, with the error message "Lowering not implemented yet".

auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
if (!teams) {
emitError(loc, "workdistribute not nested in teams\n");
return false;
Contributor

Same as the other ErrorOr<bool> comment.

Contributor Author

Thanks again for feedback on this. Used FailureOr<> in the pass to emit errors and fail.
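
A minimal sketch of the FailureOr-based error handling mentioned above (names and structure are illustrative, not the exact code in the pass):

static FailureOr<bool> fissionWorkdistribute(omp::WorkdistributeOp workdistribute) {
  auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
  if (!teams) {
    emitError(workdistribute.getLoc(), "workdistribute not nested in teams");
    return failure();
  }
  bool changed = false;
  // ... perform the fission, setting `changed` when the IR is modified ...
  return changed;
}

// Sketch of how the pass could surface failures (class/base names illustrative).
struct LowerWorkdistributePass
    : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
  void runOnOperation() override {
    WalkResult result = getOperation()->walk([&](omp::WorkdistributeOp wd) {
      if (failed(fissionWorkdistribute(wd)))
        return WalkResult::interrupt();
      return WalkResult::advance();
    });
    if (result.wasInterrupted())
      signalPassFailure();
  }
};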

return false;
} else {
workdistribute.emitError() << "Non-runtime fir.call lowering not "
"supported in workdistribute yet.";
Contributor

The error message says not supported yet, but I think this is something that should never happen if our pipeline is sound. (workdistribute should never have non-rt calls in it)

Contributor Author

Fixed it. We already detect any user-defined functions in semantics checks. Test : workdistribute01.f90

};

// moveToHost method clones all the ops from target region outside of it.
// It hoists runtime functions and replaces them with omp vesions.
Contributor

replaces them with omp vesions

Typo and also is this description outdated? For assign we seem to codegen an assign implementation and not replace with an omp version of the runtime function.

Contributor Author

Updated description.
For array-to-array assignments, __FortranAAssign is replaced with omp.target_memcpy.
For scalar-to-array assignments, workdistributeRuntimeCallLower handles the lowering of __FortranAAssign.
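
For example, the scalar-to-array case could be rewritten roughly like this (schematic, not the exact IR emitted by the pass):

// inside teams { workdistribute { ... } }:
fir.call @_FortranAAssign(%dst_box, %scalar_box, %file, %line)

// is lowered by workdistributeRuntimeCallLower to an unordered loop, which
// WorkdistributeDoLower then turns into teams { parallel { distribute { wsloop { loop_nest } } } }:
fir.do_loop %i = %c1 to %extent step %c1 unordered {
  %elem = fir.array_coor %dst(%shape) %i
  fir.store %scalar to %elem
}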

Contributor

@kparzysz kparzysz left a comment

I haven't read the entirety of the code, but I have some comments so far.

return false;
return *unordered;
}
// We cannot parallise anything else.
Contributor

Typo: parallise

}
if (shouldParallelize(&op)) {
emitError(loc, "teams has parallelize ops before first workdistribute\n");
return false;
Contributor

This shouldn't be an error. It's legal to have intervening code in teams before workdistribute.

Comment on lines 205 to 207
omp::TargetOp targetOp;
// Get the target op parent of teams
targetOp = dyn_cast<omp::TargetOp>(teams->getParentOp());
Contributor

Unused.

Also, teams doesn't have to be nested inside a target, it can be a top-level construct.

Contributor Author

Removed this unused code. fissionWorkdistribute shouldn't depend on teams having a parent target op. Thanks.

Contributor

That's a good point - I don't think we have any tests for

teams workdistribute

We only have

target teams workdistribute

We should probably test that. (and make sure it works)

Contributor Author

Extended tests to check teams workdistribute as well in the latest patch.
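
For reference, a host-only teams workdistribute case of the kind discussed above could look like this in Fortran (illustrative; not necessarily the exact test that was added):

subroutine teams_workdistribute(n, a, x, y)
  integer :: n
  real :: a
  real :: x(n), y(n)
  !$omp teams workdistribute
  y = a * x + y      ! array statement lowered to fir.do_loop ... unordered
  !$omp end teams workdistribute
end subroutine teams_workdistribute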

Comment on lines 186 to 198
// A()
// omp.teams {
// omp.workdistribute {
// B()
// }
// }
// C()
// omp.teams {
// omp.workdistribute {
// D()
// }
// }
// E()
Contributor

I'm not sure if splitting teams is the right thing to do. Even then, it should be an optimization rather than a part of lowering. We should be able to generate code for workdistribute without splitting teams.

Contributor

I spoke with Michael and it seems that this is necessary for correctness. If that's the case, disregard my previous comment.

Contributor Author

This implementation of splitting teams follows the paper "Automatic Parallelization and OpenMP Offloading of Fortran Array Notation".

AFAIU, in OpenMP, every for loop that's intended to be parallelized would be under its own teams construct, like "teams parallel distribute for".
A single teams could have multiple do_loops like below, and we would need to rethink the strategy for handling such scenarios without splitting teams.

teams {
workdistribute {
fir.do_loop unordered {}
fir.do_loop unordered {}
}
}

Contributor

It is necessary for correctness because there is no target-wide barrier construct in OpenMP (or on every GPU we want to target) and the only way to achieve that is (I think) to launch separate kernels.

}

// This is the single source of truth about whether we should parallelize an
// operation nested in an omp.workdistribute region.
Contributor

Maybe add a comment that this parallelize here refers to the divide into units of work from the standard.

Contributor Author

Updated in latest patch
